import json

import config
import numpy as np
import matplotlib.pyplot as plt


def paint(all_data):
    """Plot the loss curves recorded in a training log.

    Every element of ``all_data`` is treated as one log line. A usable line
    is a JSON object with a ``type`` field and a ``data`` field; ``data`` is
    passed to ``list.extend`` and ``np.mean``, so it is assumed to be a list
    of numbers (TODO confirm against the code that writes the log).

    Records whose ``type`` is ``'epoch_loss'`` count as one epoch each; the
    ``data`` of every other record is collected as ``all_loss``.

    Three figures are shown one after another with ``plt.show()``:
      1. all ``epoch_loss`` values, concatenated in log order;
      2. the mean of each ``epoch_loss`` record;
      3. all values from the remaining records, concatenated in log order.

    Lines that cannot be used (invalid JSON, JSON that is not an object, or
    an object without ``type``/``data``) are reported on stdout with their
    1-based line number and skipped. Blank lines are skipped silently.

    Args:
        all_data: sequence of log lines (strings).
    """
    epoch_loss = []
    epoch_loss_mean = []
    all_loss = []
    epoch = 0

    for line_no, line in enumerate(all_data, start=1):
        # Splitting a file that ends with a newline yields a trailing empty
        # string; that is not a malformed record, so do not report it.
        if isinstance(line, str) and not line.strip():
            continue
        try:
            record = json.loads(line)
            # The lookups stay inside the try: valid JSON that is not the
            # expected object (e.g. a bare number or string, or an object
            # missing a field) must be skipped, not crash the whole run.
            data_type = record['type']
            data = record['data']
        except (ValueError, TypeError, KeyError) as e:
            # json.JSONDecodeError is a ValueError subclass.
            print("第{}行不是需要的数据，错误原因->{}".format(line_no, e))
            continue

        if data_type == 'epoch_loss':
            epoch_loss.extend(data)
            epoch_loss_mean.append(np.mean(data))
            epoch += 1
        else:
            all_loss.extend(data)

    # All three figures share the same x-axis caption (the epoch count).
    x_label = 'epoch = {}'.format(epoch)
    figures = (
        ('epoch_loss', epoch_loss),
        ('epoch_loss_mean', epoch_loss_mean),
        ('all_loss', all_loss),
    )
    for y_label, series in figures:
        plt.plot(series)
        plt.ylabel(y_label)
        plt.xlabel(x_label)
        plt.show()

def main():
    """Read the log file at ``config.log_path`` and plot its loss curves.

    The whole file is read as UTF-8 text, split into lines, and handed to
    ``paint``.
    """
    # The context manager guarantees the file is closed, even if reading
    # raises; the original left the handle open.
    with open(config.log_path, encoding='utf-8', mode='r') as fs:
        all_data = fs.read()
    # Split on '\n' only: str.splitlines() would also break on characters
    # such as U+2028, which may appear unescaped inside a JSON string.
    lines = all_data.split('\n')
    paint(lines)


if __name__ == '__main__':
    # Run only when executed as a script, so that importing this module
    # does not read the log file and open (blocking) plot windows.
    main()
